#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

"""
Focal_loss:
https://arxiv.org/abs/1708.02002

Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)

"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Focal loss (https://arxiv.org/abs/1708.02002) computed on top of the
    mean cross entropy.

    NOTE(review): the cross entropy is reduced to a batch mean *before* the
    focal modulation is applied, so (1-pt)^gamma acts on the batch-average
    probability rather than per sample — confirm this is intended.
    """
    def __init__(self, *args, alpha=0.5, gamma=2, weight=None, ignore_index=255, **kwargs):
        """
        :param alpha:        (float) scalar balancing factor
        :param gamma:        (float) gamma > 0; down-weights well-classified examples
        :param weight:       optional per-class weights, forwarded to CrossEntropyLoss
        :param ignore_index: (int) target value excluded from the loss
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.ce_fn = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index)
        self.is_avg = False

    def forward(self, preds, labels):
        # logpt = log(pt) where pt is the (batch-mean) probability of the true class
        logpt = -self.ce_fn(preds, labels)
        pt = torch.exp(logpt)
        loss = -((1 - pt) ** self.gamma) * self.alpha * logpt
        return loss

    def info(self):
        # BUGFIX: previously reported 'CrossEntropyLoss' (copy/paste from a sibling loss)
        return {'value':'loss', 'name':'FocalLoss', 'is_avg':self.is_avg}

    def clear(self):
        return

    @classmethod
    def args(cls):
        return ['weight']


class FocalLoss2D(nn.Module):
    """
    Focal Loss for dense prediction, as proposed in:
        "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002v2)

    Loss(x, class) = -alpha * (1 - softmax(x)[class])**gamma * log(softmax(x)[class])
    """
    def __init__(self, ignore_index=255, alpha=0.25, gamma=2, weight=None, size_average=True):
        """
        :param ignore_index:  (int) target label value excluded from the loss
        :param alpha:         (float) scalar balancing factor
        :param gamma:         (float) gamma > 0;
                                      reduces the relative loss for well-classified examples
                                      (probabilities > .5), putting more focus on hard,
                                      mis-classified examples
        :param weight:        (sequence of float or None) optional per-class weights,
                              registered as a buffer so it follows .to()/.cuda().
                              NOTE(review): currently unused in forward() — confirm.
        :param size_average:  (bool) average the loss over valid (non-ignored) pixels;
                              if False, the per-pixel losses are returned un-reduced
        """
        super(FocalLoss2D, self).__init__()
        # BUGFIX: register_buffer() returns None, so the original one-liner
        # `self.weight = None if ... else self.register_buffer(...)` always
        # left self.weight as None, silently discarding the given weights.
        if weight is None:
            self.weight = None
        else:
            self.register_buffer('weight', torch.FloatTensor(weight))
        self.alpha = alpha
        self.gamma = gamma
        self.one_hot = None
        self.ignore_index = ignore_index
        self.size_average = size_average
        self.is_avg = False

    def forward(self, input_img, cls_preds, cls_targets):
        """
        :param input_img:    unused; kept for a uniform loss-module interface
        :param cls_preds:    (n, c, h, w) raw class scores
        :param cls_targets:  (n, 1, h, w) integer labels (4-D, enforced below)
        :return: scalar loss if size_average, else per-valid-pixel losses
        """
        assert not cls_targets.requires_grad
        cls_targets = cls_targets.long()
        assert cls_targets.dim() == 4
        assert cls_preds.size(0) == cls_targets.size(0), "{0} vs {1} ".format(cls_preds.size(0), cls_targets.size(0))
        # BUGFIX: the message previously printed cls_targets.size(1) while comparing size(2)
        assert cls_preds.size(2) == cls_targets.size(2), "{0} vs {1} ".format(cls_preds.size(2), cls_targets.size(2))
        assert cls_preds.size(3) == cls_targets.size(3), "{0} vs {1} ".format(cls_preds.size(3), cls_targets.size(3))

        n, c, h, w = cls_preds.size()
        self.one_hot = torch.eye(c)
        if cls_preds.is_cuda:
            self.one_hot = self.one_hot.cuda()
        # zero scalar on the right device/dtype, returned when no pixel is valid
        loss_zero = torch.zeros_like(cls_preds[0].view(-1)[0])
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 1. target reshape and one-hot encode
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 1.1. target: (n*h*w, 1); mask selects in-range, non-ignored labels
        cls_targets = cls_targets.view(n * h * w, 1)
        target_mask = (cls_targets >= 0) & (cls_targets != self.ignore_index)

        cls_targets = cls_targets[target_mask]
        cls_targets = self.one_hot.index_select(dim=0, index=cls_targets)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 2. compute focal loss for multi-classification
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 2.1. The softmax. probs: (n, c, h, w)
        probs = F.softmax(cls_preds, dim=1)
        # 2.2. probs: (n*h*w, c) - contiguous() required if transpose() is used before view().
        probs = probs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
        probs = probs[target_mask.repeat(1, c)]
        probs = probs.view(-1, c)  # (num_valid, c)

        # probability of the true class per valid pixel; clamp avoids log(0)
        probs = torch.clamp((probs * cls_targets).sum(1).view(-1, 1), min=1e-8, max=1.0)
        batch_loss = -self.alpha * (torch.pow((1 - probs), self.gamma)) * probs.log()

        if self.size_average:
            sum_val = target_mask.data.sum().float() if target_mask.numel() > 0 else 0
            # average over valid pixels; fall back to a zero scalar if none
            batch_loss = (batch_loss.sum() / sum_val) if (sum_val != 0) else loss_zero
        return batch_loss

    def info(self):
        return {'value':'loss', 'name':'FocalLoss2D', 'is_avg':self.is_avg}

    def clear(self):
        return

    @classmethod
    def args(cls):
        return ['weight']

segmentation_focal_loss = FocalLoss2D


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1 - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def lovasz_softmax(prob, lbl, ignore_index, only_present):
    """
    Multi-class Lovasz-Softmax loss, flattened over the whole batch.

    :param prob: [B, C, H, W] class probabilities at each prediction (between 0 and 1)
    :param lbl:  [B, 1, H, W] ground truth labels (between 0 and C - 1, or
                 ignore_index); must be 4-D, matching the assert below
    :param ignore_index: void label value excluded from the loss (or None)
    :param only_present: average only over classes present in the ground truth
    :return: scalar loss (a gradient-carrying zero if every pixel is ignored)
    """
    assert prob.ndim == lbl.ndim == 4
    C = prob.shape[1]
    prob = prob.permute(0, 2, 3, 1).contiguous().view(-1, C)  # (B*H*W, C)
    lbl = lbl.view(-1)  # (B*H*W,)
    if ignore_index is not None:
        mask = (lbl != ignore_index)
        if mask.sum() == 0:
            # every pixel is void: return a zero that still carries gradients
            return torch.mean(prob * 0)
        prob = prob[mask]
        lbl = lbl[mask]

    total_loss = 0
    cnt = 0
    for c in range(C):
        fg = (lbl == c).float()  # foreground indicator for class c
        if only_present and fg.sum() == 0:
            continue
        errors = (fg - prob[:, c]).abs()
        errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        total_loss += torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        cnt += 1
    # BUGFIX: guard against division by zero when only_present skips every
    # class (possible if all remaining labels fall outside [0, C))
    if cnt == 0:
        return torch.mean(prob * 0)
    return total_loss / cnt


class LovaszSoftmax(nn.Module):
    """
    Multi-class Lovasz-Softmax loss.
      logits: [B, C, H, W] class logits at each prediction (between -inf and +inf)
      labels: [B, 1, H, W] Tensor, ground truth labels (between 0 and C - 1);
              4-D, since lovasz_softmax() asserts prob.ndim == lbl.ndim == 4
      ignore_index: void class labels
      only_present: average only on classes present in ground truth
    """
    def __init__(self, ignore_index=255, only_present=True, weight=None):
        super().__init__()
        # BUGFIX: register_buffer() returns None, so the original one-liner
        # `self.weight = None if ... else self.register_buffer(...)` always
        # left self.weight as None. Registering the buffer makes the tensor
        # available as self.weight and moves with the module on .to()/.cuda().
        # NOTE(review): weight is currently unused in forward() — confirm.
        if weight is None:
            self.weight = None
        else:
            self.register_buffer('weight', torch.FloatTensor(weight))
        self.ignore_index = ignore_index
        self.only_present = only_present
        self.is_avg = False
        # compute the loss per image and average, instead of batch-flattened
        self.per_image_iou = True

    def forward(self, input_img, cls_preds, cls_targets):
        """
        :param input_img:   unused; kept for a uniform loss-module interface
        :param cls_preds:   (B, C, H, W) raw class logits
        :param cls_targets: (B, 1, H, W) integer labels
        :return: scalar loss averaged over the batch
        """
        prbs = F.softmax(cls_preds, dim=1)
        total_loss = 0
        batch_size = cls_preds.shape[0]
        if self.per_image_iou:
            for prb, lbl in zip(prbs, cls_targets):
                prb = prb.unsqueeze(0)
                lbl = lbl.unsqueeze(0)
                total_loss += lovasz_softmax(prb, lbl, ignore_index=self.ignore_index, only_present=self.only_present)
        else:
            total_loss = lovasz_softmax(prbs, cls_targets, ignore_index=self.ignore_index, only_present=self.only_present)

        return total_loss / batch_size

    def info(self):
        return {'value':'loss', 'name':'LovaszSoftmax', 'is_avg':self.is_avg}

    def clear(self):
        return

    @classmethod
    def args(cls):
        return ['weight']

segmentation_lovasz_loss = LovaszSoftmax
